// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <gtest/gtest.h>
#include "tt_metal/tt_metal/common/mesh_dispatch_fixture.hpp"

#include "debug_tools_test_utils.hpp"
#include "hal_types.hpp"

namespace tt::tt_metal {

// Shared base for the DPRINT and watcher mesh fixtures declared in this header.
class DebugToolsMeshFixture : public MeshDispatchFixture {
protected:
    // Watcher-enabled flag recorded by derived fixtures in SetUp() and written back in their TearDown().
    bool watcher_previous_enabled{};

    // Nothing extra to clean up at this level; simply forward to the dispatch fixture.
    void TearDown() override { MeshDispatchFixture::TearDown(); }

    // Adapts a test body that expects the concrete fixture type T (plus the mesh device) into the
    // zero-argument callable that MeshDispatchFixture::RunTestOnDevice is invoked with.
    // `this` is downcast to T*, so T must be the dynamic type of the fixture (or one of its bases).
    template <typename T>
    void RunTestOnDevice(
        const std::function<void(T*, std::shared_ptr<distributed::MeshDevice>)>& run_function,
        const std::shared_ptr<distributed::MeshDevice>& mesh_device) {
        // The callable and the device handle are captured by value, as in the wrapped call.
        auto bound_test = [this, run_function, mesh_device]() {
            T* const fixture = static_cast<T*>(this);
            run_function(fixture, mesh_device);
        };
        MeshDispatchFixture::RunTestOnDevice(bound_test, mesh_device);
    }
};

// A version of MeshDispatchFixture with DPrint enabled on all cores.
class DPrintMeshFixture : public DebugToolsMeshFixture {
public:
    // DPRINT output is redirected to this file so a test can check it after the program has run.
    inline static const std::string dprint_file_name = "gtest_dprint_log.txt";

    // Runs the workload through the base fixture and then waits for the print server to catch up.
    void RunProgram(const std::shared_ptr<distributed::MeshDevice>& mesh_device, distributed::MeshWorkload& workload) {
        DebugToolsMeshFixture::RunProgram(mesh_device, workload);
        MetalContext::instance().dprint_server()->await();
    }

protected:
    // Test-mode flag as found on entry to SetUp(); TearDown() writes it back so that a test mode
    // that was already enabled before this fixture ran is not silently switched off.
    bool test_mode_previous{};

    // Running with dprint + watcher enabled can make the code size blow up, so let's force watcher
    // disabled for DPRINT tests.
    void SetUp() override {
        auto& rt_opts = tt::tt_metal::MetalContext::instance().rtoptions();

        // The core range (virtual) needs to be set >= the set of all cores
        // used by all tests using this fixture, so set dprint enabled for
        // all cores and all devices
        rt_opts.set_feature_enabled(tt::llrt::RunTimeDebugFeatureDprint, true);
        rt_opts.set_feature_prepend_device_core_risc(tt::llrt::RunTimeDebugFeatureDprint, false);
        rt_opts.set_feature_all_cores(
            tt::llrt::RunTimeDebugFeatureDprint, CoreType::WORKER, tt::llrt::RunTimeDebugClassWorker);
        rt_opts.set_feature_all_cores(
            tt::llrt::RunTimeDebugFeatureDprint, CoreType::ETH, tt::llrt::RunTimeDebugClassWorker);
        rt_opts.set_feature_all_chips(tt::llrt::RunTimeDebugFeatureDprint, true);
        // Send output to a file so the test can check after program is run.
        rt_opts.set_feature_file_name(tt::llrt::RunTimeDebugFeatureDprint, dprint_file_name);

        // Remember the current test mode and watcher state before overriding them.
        test_mode_previous = rt_opts.get_test_mode_enabled();
        rt_opts.set_test_mode_enabled(true);
        watcher_previous_enabled = rt_opts.get_watcher_enabled();
        rt_opts.set_watcher_enabled(false);

        ExtraSetUp();

        // Parent class initializes devices and any necessary flags
        DebugToolsMeshFixture::SetUp();
    }

    void TearDown() override {
        // Parent class tears down devices
        DebugToolsMeshFixture::TearDown();
        ExtraTearDown();

        // Looked up only after the hooks above have run, since a derived ExtraTearDown() may act
        // on the MetalContext itself.
        auto& rt_opts = tt::tt_metal::MetalContext::instance().rtoptions();

        // Reset DPrint settings
        rt_opts.set_feature_cores(tt::llrt::RunTimeDebugFeatureDprint, {});
        rt_opts.set_feature_enabled(tt::llrt::RunTimeDebugFeatureDprint, false);
        rt_opts.set_feature_all_cores(
            tt::llrt::RunTimeDebugFeatureDprint, CoreType::WORKER, tt::llrt::RunTimeDebugClassNoneSpecified);
        rt_opts.set_feature_all_cores(
            tt::llrt::RunTimeDebugFeatureDprint, CoreType::ETH, tt::llrt::RunTimeDebugClassNoneSpecified);
        rt_opts.set_feature_all_chips(tt::llrt::RunTimeDebugFeatureDprint, false);
        rt_opts.set_feature_file_name(tt::llrt::RunTimeDebugFeatureDprint, "");
        rt_opts.set_feature_prepend_device_core_risc(tt::llrt::RunTimeDebugFeatureDprint, true);

        // Restore (rather than unconditionally clear) the state captured in SetUp().
        rt_opts.set_test_mode_enabled(test_mode_previous);
        rt_opts.set_watcher_enabled(watcher_previous_enabled);
    }

    // Runs the test body on the given device, then clears the DPRINT log file so the next run
    // starts from an empty log.
    void RunTestOnDevice(
        const std::function<void(DPrintMeshFixture*, std::shared_ptr<distributed::MeshDevice>)>& run_function,
        const std::shared_ptr<distributed::MeshDevice>& mesh_device) {
        DebugToolsMeshFixture::RunTestOnDevice(run_function, mesh_device);
        MetalContext::instance().dprint_server()->clear_log_file();
    }

    // Override this function in child classes for additional setup commands between DPRINT setup
    // and device creation.
    virtual void ExtraSetUp() {}
    // Counterpart hook: called from TearDown() after the devices are torn down and before the
    // DPRINT settings are reset.
    virtual void ExtraTearDown() {}
};

// For usage by tests that need the dprint server devices disabled.
class DPrintDisableMeshDevicesFixture : public DPrintMeshFixture {
protected:
    // Mutes DPRINT on every device: the all-chips flag is cleared and the explicit chip-id list is
    // left empty.
    void ExtraSetUp() override {
        auto& rt_opts = MetalContext::instance().rtoptions();
        rt_opts.set_feature_all_chips(tt::llrt::RunTimeDebugFeatureDprint, false);
        rt_opts.set_feature_chip_ids(tt::llrt::RunTimeDebugFeatureDprint, {});
    }

    // Teardown dprint server so we can re-init later with all devices enabled again.
    void ExtraTearDown() override {
        auto& context = MetalContext::instance();
        context.teardown();
    }
};

// DPRINT fixture variant that switches on "one file per RISC" output for the duration of a test.
class DPrintSeparateFilesFixture : public DPrintMeshFixture {
public:
    // File-name suffixes, indexed by the processor index that is passed to contains() below.
    static constexpr std::array<std::string_view, 5> suffixes = {"BRISC", "NCRISC", "TRISC0", "TRISC1", "TRISC2"};

    // Compares the per-RISC output files of device 0 / worker core (0,0) against `expected`.
    // `expected` must hold exactly one entry per suffix; entries whose processor is not enabled
    // for DPRINT are not checked.
    static void check_output(std::span<const std::string> expected) {
        auto& rt_opts = tt::tt_metal::MetalContext::instance().rtoptions();
        const auto& enabled_processors = rt_opts.get_feature_processors(tt::llrt::RunTimeDebugFeatureDprint);
        ASSERT_EQ(expected.size(), suffixes.size());

        for (size_t risc = 0; risc < suffixes.size(); ++risc) {
            if (enabled_processors.contains(HalProgrammableCoreType::TENSIX, risc)) {
                auto filename = fmt::format("generated/dprint/device-0_worker-core-0-0_{}.txt", suffixes[risc]);
                EXPECT_TRUE(FilesMatchesString(filename, expected[risc]));
            }
        }
    }

protected:
    // Value of the one-file-per-RISC option before ExtraSetUp() changed it.
    bool original_one_file_per_risc_{};

    // Remember the current setting, then enable one file per RISC for this test.
    void ExtraSetUp() override {
        auto& rt_opts = tt::tt_metal::MetalContext::instance().rtoptions();
        original_one_file_per_risc_ = rt_opts.get_feature_one_file_per_risc(tt::llrt::RunTimeDebugFeatureDprint);
        rt_opts.set_feature_one_file_per_risc(tt::llrt::RunTimeDebugFeatureDprint, true);
    }

    // Put the setting back to what ExtraSetUp() found.
    void ExtraTearDown() override {
        auto& rt_opts = tt::tt_metal::MetalContext::instance().rtoptions();
        rt_opts.set_feature_one_file_per_risc(tt::llrt::RunTimeDebugFeatureDprint, original_one_file_per_risc_);
    }
};

// A version of MeshDispatchFixture with watcher enabled
class MeshWatcherFixture : public DebugToolsMeshFixture {
public:
    // Watcher log path exposed to tests (not referenced inside this fixture itself).
    inline static const std::string log_file_name = "generated/watcher/watcher.log";
    // Watcher interval applied in SetUp(), in milliseconds; also used as the settle time in
    // RunTestOnDevice().
    inline static const int interval_ms = 250;

    // Runs the workload through the base fixture. When `wait_for_dump` is true, additionally
    // blocks until the watcher server's dump count has advanced by two.
    void RunProgram(
        const std::shared_ptr<distributed::MeshDevice>& mesh_device,
        distributed::MeshWorkload& workload,
        bool wait_for_dump = false) {
        DebugToolsMeshFixture::RunProgram(mesh_device, workload);

        // Wait for watcher to run a full dump before finishing, need to wait for dump count to
        // increase because we'll likely check in the middle of a dump.
        if (wait_for_dump) {
            int curr_count = MetalContext::instance().watcher_server()->dump_count();
            while (MetalContext::instance().watcher_server()->dump_count() < curr_count + 2) {
                // Sleep briefly between polls instead of spinning at full speed: dumps are
                // interval_ms apart, so a busy loop would only burn a core for no benefit.
                std::this_thread::sleep_for(std::chrono::milliseconds(1));
            }
        }
    }

protected:
    // Option values captured in SetUp() so TearDown() can put them back.
    int watcher_previous_interval{};
    bool watcher_previous_dump_all{};
    bool watcher_previous_append{};
    bool watcher_previous_auto_unpause{};
    bool watcher_previous_noinline{};
    bool test_mode_previous{};

    void SetUp() override {
        auto& rt_opts = tt::tt_metal::MetalContext::instance().rtoptions();

        // Enable watcher for this test, save the previous state so we can restore it later.
        watcher_previous_enabled = rt_opts.get_watcher_enabled();
        watcher_previous_interval = rt_opts.get_watcher_interval();
        watcher_previous_dump_all = rt_opts.get_watcher_dump_all();
        watcher_previous_append = rt_opts.get_watcher_append();
        watcher_previous_auto_unpause = rt_opts.get_watcher_auto_unpause();
        watcher_previous_noinline = rt_opts.get_watcher_noinline();
        test_mode_previous = rt_opts.get_test_mode_enabled();

        rt_opts.set_watcher_enabled(true);
        rt_opts.set_watcher_interval(interval_ms);
        rt_opts.set_watcher_dump_all(false);
        rt_opts.set_watcher_append(false);
        rt_opts.set_watcher_auto_unpause(true);
        rt_opts.set_watcher_noinline(true);
        rt_opts.set_test_mode_enabled(true);
        // NOTE(review): this option is not saved here and is not undone in TearDown(); confirm
        // whether it should be restored like the other watcher options.
        rt_opts.set_watcher_noc_sanitize_linked_transaction(true);

        // Parent class initializes devices and any necessary flags
        DebugToolsMeshFixture::SetUp();
        MetalContext::instance().watcher_server()->clear_log();
    }

    void TearDown() override {
        // Parent class tears down devices
        DebugToolsMeshFixture::TearDown();

        // Reset watcher settings to their previous values
        auto& rt_opts = tt::tt_metal::MetalContext::instance().rtoptions();
        rt_opts.set_watcher_interval(watcher_previous_interval);
        rt_opts.set_watcher_dump_all(watcher_previous_dump_all);
        rt_opts.set_watcher_append(watcher_previous_append);
        rt_opts.set_watcher_auto_unpause(watcher_previous_auto_unpause);
        rt_opts.set_watcher_noinline(watcher_previous_noinline);
        rt_opts.set_test_mode_enabled(test_mode_previous);
        rt_opts.set_watcher_enabled(watcher_previous_enabled);
    }

    // Runs the test body on the given device, then waits one watcher interval and clears the log.
    void RunTestOnDevice(
        const std::function<void(MeshWatcherFixture*, std::shared_ptr<distributed::MeshDevice>)>& run_function,
        const std::shared_ptr<distributed::MeshDevice>& mesh_device) {
        DebugToolsMeshFixture::RunTestOnDevice(run_function, mesh_device);
        // Wait for a final watcher poll and then clear the log.
        std::this_thread::sleep_for(std::chrono::milliseconds(interval_ms));
        MetalContext::instance().watcher_server()->clear_log();
    }
};

// A version of MeshWatcherFixture with read and write debug delays enabled
class MeshWatcherDelayFixture : public MeshWatcherFixture {
public:
    // Per-feature target selections captured in SetUp(); only the read/write/atomic debug-delay
    // slots are ever written or read by this fixture.
    tt::llrt::TargetSelection saved_target_selection[tt::llrt::RunTimeDebugFeatureCount];

    // Cores on which the debug delays are enabled, keyed by core type.
    std::map<CoreType, std::vector<CoreCoord>> delayed_cores;

    void SetUp() override {
        auto& rt_opts = tt::tt_metal::MetalContext::instance().rtoptions();

        // NOTE(review): the previous debug delay value is not saved, and TearDown() does not
        // reset it — confirm whether that is intended.
        rt_opts.set_watcher_debug_delay(5000000);
        delayed_cores[CoreType::WORKER] = {{0, 0}, {1, 1}};

        // Store the previous state of the watcher features
        const auto remember_targets = [&](auto feature) {
            saved_target_selection[feature] = rt_opts.get_feature_targets(feature);
        };
        remember_targets(tt::llrt::RunTimeDebugFeatureReadDebugDelay);
        remember_targets(tt::llrt::RunTimeDebugFeatureWriteDebugDelay);
        // NOTE(review): the atomic delay selection is saved and restored but never enabled here.
        remember_targets(tt::llrt::RunTimeDebugFeatureAtomicDebugDelay);

        // Enable read and write debug delay for the test cores
        const auto enable_on_delayed_cores = [&](auto feature) {
            rt_opts.set_feature_enabled(feature, true);
            rt_opts.set_feature_cores(feature, delayed_cores);
        };
        enable_on_delayed_cores(tt::llrt::RunTimeDebugFeatureReadDebugDelay);
        enable_on_delayed_cores(tt::llrt::RunTimeDebugFeatureWriteDebugDelay);

        // Call parent
        MeshWatcherFixture::SetUp();
    }

    void TearDown() override {
        // Call parent
        MeshWatcherFixture::TearDown();

        // Restore the target selections captured in SetUp()
        auto& rt_opts = tt::tt_metal::MetalContext::instance().rtoptions();
        const auto restore_targets = [&](auto feature) {
            rt_opts.set_feature_targets(feature, saved_target_selection[feature]);
        };
        restore_targets(tt::llrt::RunTimeDebugFeatureReadDebugDelay);
        restore_targets(tt::llrt::RunTimeDebugFeatureWriteDebugDelay);
        restore_targets(tt::llrt::RunTimeDebugFeatureAtomicDebugDelay);
    }
};

} // namespace tt::tt_metal
